# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import inspect

import gin
import tensorflow as tf

from official.nlp.modeling import layers

# Encoder scaffold: a configurable BERT-style encoder with pluggable
# embedding and hidden-layer components.

@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(tf.keras.Model):
    """Bi-directional Transformer-based encoder network scaffold.
  
    This network allows users to flexibly implement an encoder similar to the one
    described in "BERT: Pre-training of Deep Bidirectional Transformers for
    Language Understanding" (https://arxiv.org/abs/1810.04805).
  
    In this network, users can choose to provide a custom embedding subnetwork
    (which will replace the standard embedding logic) and/or a custom hidden layer
    class (which will replace the Transformer instantiation in the encoder). For
    each of these custom injection points, users can pass either a class or a
    class instance. If a class is passed, that class will be instantiated using
    the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
    is passed, that instance will be invoked. (In the case of hidden_cls, the
    instance will be invoked 'num_hidden_instances' times.)
  
    If the hidden_cls is not overridden, a default transformer layer will be
    instantiated.
  
    Arguments:
      pooled_output_dim: The dimension of pooled output.
      pooler_layer_initializer: The initializer for the classification
        layer.
      embedding_cls: The class or instance to use to embed the input data. This
        class or instance defines the inputs to this encoder and outputs
        (1) embeddings tensor with shape [batch_size, seq_length, hidden_size] and
        (2) attention masking with tensor [batch_size, seq_length, seq_length].
        If embedding_cls is not set, a default embedding network
        (from the original BERT paper) will be created.
      embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
        be instantiated. If embedding_cls is not set, a config dict must be
        passed to 'embedding_cfg' with the following values:
        "vocab_size": The size of the token vocabulary.
        "type_vocab_size": The size of the type vocabulary.
        "hidden_size": The hidden size for this encoder.
        "max_seq_length": The maximum sequence length for this encoder.
        "seq_length": The sequence length for this encoder.
        "initializer": The initializer for the embedding portion of this encoder.
        "dropout_rate": The dropout rate to apply before the encoding layers.
      embedding_data: A reference to the embedding weights that will be used to
        train the masked language model, if necessary. This is optional, and only
        needed if (1) you are overriding embedding_cls and (2) are doing standard
        pretraining.
      num_hidden_instances: The number of times to instantiate and/or invoke the
        hidden_cls.
      hidden_cls: The class or instance to encode the input data. If hidden_cls is
        not set, a KerasBERT transformer layer will be used as the encoder class.
      hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
        instantiated. If hidden_cls is not set, a config dict must be passed to
        'hidden_cfg' with the following values:
          "num_attention_heads": The number of attention heads. The hidden size
            must be divisible by num_attention_heads.
          "intermediate_size": The intermediate size of the transformer.
          "intermediate_activation": The activation to apply in the transformer.
          "dropout_rate": The overall dropout rate for the transformer layers.
          "attention_dropout_rate": The dropout rate for the attention layers.
          "kernel_initializer": The initializer for the transformer layers.
      return_all_layer_outputs: Whether to output sequence embedding outputs of
        all encoder transformer layers.
    """
    
    def __init__(
            self,
            pooled_output_dim,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            embedding_cls=None,              # Custom embedding network (class or instance).
            embedding_cfg=None,
            embedding_data=None,
            num_hidden_instances=1,
            hidden_cls=layers.Transformer,   # Hidden (Transformer) layer class or instance.
            hidden_cfg=None,
            return_all_layer_outputs=False,
            **kwargs):
        # Disable Keras attribute tracking so constructor arguments can be
        # stashed on `self` before super().__init__ is called at the end.
        self._self_setattr_tracking = False
        self._hidden_cls = hidden_cls
        self._hidden_cfg = hidden_cfg
        self._num_hidden_instances = num_hidden_instances
        self._pooled_output_dim = pooled_output_dim
        self._pooler_layer_initializer = pooler_layer_initializer
        self._embedding_cls = embedding_cls
        self._embedding_cfg = embedding_cfg
        self._embedding_data = embedding_data
        self._return_all_layer_outputs = return_all_layer_outputs
        self._kwargs = kwargs
        
        # Embedding subnetwork.
        if embedding_cls:
            # If a class was passed, instantiate it (with embedding_cfg kwargs
            # when provided).
            if inspect.isclass(embedding_cls):
                self._embedding_network = embedding_cls(
                    **embedding_cfg) if embedding_cfg else embedding_cls()
            # If an instance was passed, use it directly.
            else:
                self._embedding_network = embedding_cls
            inputs = self._embedding_network.inputs
            embeddings, attention_mask = self._embedding_network(inputs)
        else:
            # No custom embedding network was provided: build the default
            # BERT-style embedding stack from embedding_cfg.
            self._embedding_network = None
            word_ids = tf.keras.layers.Input(
                shape=(embedding_cfg['seq_length'],),
                dtype=tf.int32,
                name='input_word_ids')
            mask = tf.keras.layers.Input(
                shape=(embedding_cfg['seq_length'],),
                dtype=tf.int32,
                name='input_mask')
            type_ids = tf.keras.layers.Input(
                shape=(embedding_cfg['seq_length'],),
                dtype=tf.int32,
                name='input_type_ids')
            inputs = [word_ids, mask, type_ids]
            
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=embedding_cfg['initializer'],
                name='word_embeddings')
            
            word_embeddings = self._embedding_layer(word_ids)
            
            # Always uses dynamic slicing for simplicity.
            self._position_embedding_layer = layers.PositionEmbedding(
                initializer=embedding_cfg['initializer'],
                use_dynamic_slicing=True,
                max_sequence_length=embedding_cfg['max_seq_length'],
                name='position_embedding')
            position_embeddings = self._position_embedding_layer(word_embeddings)
            
            type_embeddings = (
                layers.OnDeviceEmbedding(
                    vocab_size=embedding_cfg['type_vocab_size'],
                    embedding_width=embedding_cfg['hidden_size'],
                    initializer=embedding_cfg['initializer'],
                    use_one_hot=True,
                    name='type_embeddings')(type_ids))
            
            embeddings = tf.keras.layers.Add()(
                [word_embeddings, position_embeddings, type_embeddings])
            embeddings = (
                tf.keras.layers.LayerNormalization(
                    name='embeddings/layer_norm',
                    axis=-1,
                    epsilon=1e-12,
                    dtype=tf.float32)(embeddings))
            embeddings = (
                tf.keras.layers.Dropout(
                    rate=embedding_cfg['dropout_rate'])(embeddings))
            
            attention_mask = layers.SelfAttentionMask()([embeddings, mask])
        
        data = embeddings
        
        layer_output_data = []
        self._hidden_layers = []
        
        # Stack the hidden (Transformer) layers.
        for _ in range(num_hidden_instances):
            # If a class was passed, instantiate a fresh layer each iteration.
            if inspect.isclass(hidden_cls):
                layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
            # If an instance was passed, the same layer object is reused
            # (invoked once per iteration, so its weights are shared).
            else:
                layer = hidden_cls
            data = layer([data, attention_mask])
            layer_output_data.append(data)
            self._hidden_layers.append(layer)
        
        # Pool by taking the hidden state of the first ([CLS]) token of the
        # last layer's output and passing it through a dense 'tanh' layer.
        first_token_tensor = (
            tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
                layer_output_data[-1]))
        self._pooler_layer = tf.keras.layers.Dense(
            units=pooled_output_dim,
            activation='tanh',
            kernel_initializer=pooler_layer_initializer,
            name='cls_transform')
        cls_output = self._pooler_layer(first_token_tensor)
        
        if return_all_layer_outputs:
            outputs = [layer_output_data, cls_output]
        else:
            outputs = [layer_output_data[-1], cls_output]
        
        # Build the functional model. super().__init__ is deliberately called
        # last, once the full input->output graph has been constructed.
        super(EncoderScaffold, self).__init__(
            inputs=inputs, outputs=outputs, **kwargs)
    
    def get_config(self):
        """Returns the config dict used to serialize this network.

        Note: the (possibly instantiated) embedding network is stored under the
        'embedding_cls' key, and hidden_cls is stored by registered name when it
        is a class (see from_config for the inverse).
        """
        config_dict = {
            'num_hidden_instances':
                self._num_hidden_instances,
            'pooled_output_dim':
                self._pooled_output_dim,
            'pooler_layer_initializer':
                self._pooler_layer_initializer,
            'embedding_cls':
                self._embedding_network,
            'embedding_cfg':
                self._embedding_cfg,
            'hidden_cfg':
                self._hidden_cfg,
            'return_all_layer_outputs':
                self._return_all_layer_outputs,
        }
        # A class is serialized by its Keras-registered name; an instance is
        # stored directly (Keras serializes it as a layer).
        if inspect.isclass(self._hidden_cls):
            config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
                self._hidden_cls)
        else:
            config_dict['hidden_cls'] = self._hidden_cls
        
        config_dict.update(self._kwargs)
        return config_dict
    
    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Creates an EncoderScaffold from a config produced by get_config."""
        # Resolve a serialized class name back to the registered class.
        if 'hidden_cls_string' in config:
            config['hidden_cls'] = tf.keras.utils.get_registered_object(
                config['hidden_cls_string'], custom_objects=custom_objects)
            del config['hidden_cls_string']
        return cls(**config)
    
    def get_embedding_table(self):
        """Returns the word-embedding weights, for tying to an MLM output head.

        Raises:
          RuntimeError: if a custom embedding network was used but no
            embedding_data reference was supplied at construction time.
        """
        if self._embedding_network is None:
            # In this case, we don't have a custom embedding network and can return
            # the standard embedding data.
            return self._embedding_layer.embeddings
        
        if self._embedding_data is None:
            raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                                'to the embedding data. This is required when you '
                                'pass a custom embedding network to the scaffold. '
                                'It is also possible that you are trying to get '
                                'embedding data from an embedding scaffold with a '
                                'custom embedding network where the scaffold has '
                                'been serialized and deserialized. Unfortunately, '
                                'accessing custom embedding references after '
                                'serialization is not yet supported.') % self.name)
        else:
            return self._embedding_data
    
    @property
    def hidden_layers(self):
        """List of hidden layers in the encoder."""
        return self._hidden_layers
    
    @property
    def pooler_layer(self):
        """The pooler dense layer after the transformer layers."""
        return self._pooler_layer